import os
import argparse
import torch
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import re
from collections import defaultdict
from evaluator import BiasEvaluator
import pandas as pd
from glob import glob
from tqdm import tqdm
from pprint import pprint
import os



# Manifest of generated images read by every IQA mode.
_ANNOTATION_CSV = "generated_image_list.csv"
# Values used to slice the manifest's ``file_gender`` column.
_GENDERS = ("male", "female")
# Category ids used to slice the manifest's ``category`` column.
_CATEGORIES = range(1, 7)
# Each evaluator DQA call is unpacked into exactly this many values.
_NUM_DQA = 6

# Per-IQA setting:
#   extract     -- name of the BiasEvaluator feature-extraction method
#   dqa         -- name of the BiasEvaluator DQA method
#   models      -- model names passed to the DQA method
#   path_prefix -- directory joined onto manifest paths (None = use as-is)
#   verbose     -- print per-class / per-model progress lines
_IQA_CONFIG = {
    "1": {
        "extract": "extract_sdxl_IQA1",
        "dqa": "run_sdxl_DQA",
        "models": ['mae', 'msn', 'inception', 'vgg', 'resnet', 'vit_1k', 'vit_21k',
                   'swin_1k', 'swin_21k', 'moco', 'dino_resnet', 'dino_vit',
                   'clip_vit', 'clip_resnet'],
        "path_prefix": None,
        "verbose": True,
    },
    "2": {
        "extract": "extract_IQA_2",
        "dqa": "DQA_IQA2_with_bootstrap",
        "models": ['blip_prompt1', 'blip_prompt2', 'paligemma_prompt1', 'paligemma_prompt2'],
        "path_prefix": 'diffusion-perturbations',
        "verbose": False,
    },
    "3": {
        "extract": "extract_IQA_3",
        "dqa": "DQA_IQA3_with_bootstrap",
        "models": ['IQA3_faces', 'IQA3_flive'],
        "path_prefix": 'diffusion-perturbations',
        "verbose": False,
    },
}


def _to_scalar(value):
    """Return ``value.item()`` when the object has it (e.g. a tensor), else ``value``."""
    return value.item() if hasattr(value, 'item') else value


def _extract_features(annotation, class_list, race_list, extract, path_prefix):
    """Call ``extract`` once per (profession, gender, race, category) cell.

    Args:
        annotation: manifest DataFrame with ``profession``, ``file_gender``,
            ``race``, ``category`` and ``file_path`` columns.
        class_list: professions to iterate, in order.
        race_list: races to iterate, in order.
        extract: bound evaluator method called as
            ``extract(paths, profession, gender, race, category)``.
        path_prefix: directory joined onto every file path, or None.
    """
    for profession in class_list:
        by_profession = annotation[annotation['profession'] == profession]
        for gender in _GENDERS:
            by_gender = by_profession[by_profession['file_gender'] == gender]
            for race in race_list:
                by_race = by_gender[by_gender['race'] == race]
                for category in _CATEGORIES:
                    cell = by_race[by_race['category'] == category]
                    paths = list(cell['file_path'])
                    if path_prefix is not None:
                        paths = [os.path.join(path_prefix, p) for p in paths]
                    # Invoked even for empty cells, exactly as before; the
                    # evaluator decides how to treat an empty path list.
                    extract(paths, profession, gender, race, category)


def _run_dqa(run_dqa, profession, model_name, race_list, mode):
    """Call the DQA method for one mode and return its ``_NUM_DQA`` values as a tuple.

    Raises:
        ValueError: if the method does not yield exactly ``_NUM_DQA`` values.
    """
    values = tuple(run_dqa(profession, model_name, race_list, mode=mode))
    if len(values) != _NUM_DQA:
        raise ValueError(
            f"{mode} DQA for {profession!r}/{model_name!r} returned "
            f"{len(values)} values, expected {_NUM_DQA}"
        )
    return values


def _collect_results(class_list, race_list, models, run_dqa, verbose):
    """Build one result row per profession with gender/race DQA columns per model.

    Column names are ``{mode}_DQA{i}_{model_name}`` for mode in gender, race
    and i in 1..6. They are inserted in that order so the CSV column layout
    is stable.
    """
    results = []
    for profession in tqdm(class_list):
        if verbose:
            print(f'Start class {profession}')
        class_result = {'profession': profession}
        for model_name in tqdm(models):
            if verbose:
                print(f'Start model {model_name}')
            # Gender is queried before race, preserving the original call order
            # in case the evaluator is stateful (e.g. random resampling).
            gender_vals = _run_dqa(run_dqa, profession, model_name, race_list, 'gender')
            race_vals = _run_dqa(run_dqa, profession, model_name, race_list, 'race')
            for mode, values in (('gender', gender_vals), ('race', race_vals)):
                for idx, value in enumerate(values, start=1):
                    class_result[f'{mode}_DQA{idx}_{model_name}'] = _to_scalar(value)
        results.append(class_result)
    return results


def main(args):
    """Run feature extraction and DQA for one IQA setting and save a CSV.

    Reads ``generated_image_list.csv`` from the working directory. Extracts
    features for every (profession, gender, race, category) cell, then
    computes gender- and race-mode DQA values per profession and model.
    Writes one row per profession to ``result/sdxl_<iqa>.csv``.

    Args:
        args: parsed CLI namespace. ``args.iqa`` selects the pipeline
            ('1', '2' or '3'). The whole namespace is forwarded to
            ``BiasEvaluator``.

    Raises:
        ValueError: if ``args.iqa`` is not a supported setting (previously an
            empty CSV was written silently).
    """
    try:
        config = _IQA_CONFIG[args.iqa]
    except KeyError:
        raise ValueError(
            f"Unsupported iqa value {args.iqa!r}; expected one of {sorted(_IQA_CONFIG)}"
        ) from None

    # Same selection the script entry point makes after restricting
    # CUDA_VISIBLE_DEVICES; computed here so main() does not depend on a
    # module-level ``device`` global that only exists when run as a script.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    evaluator = BiasEvaluator(args, device)
    annotation = pd.read_csv(_ANNOTATION_CSV)

    class_list = list(annotation['profession'].unique())
    race_list = list(annotation['race'].unique())

    _extract_features(
        annotation, class_list, race_list,
        getattr(evaluator, config["extract"]), config["path_prefix"],
    )
    results = _collect_results(
        class_list, race_list, config["models"],
        getattr(evaluator, config["dqa"]), config["verbose"],
    )

    # to_csv does not create missing parent directories.
    os.makedirs('result', exist_ok=True)
    pd.DataFrame(results).to_csv(f'result/sdxl_{args.iqa}.csv', index=False)
            
if __name__ == '__main__':
    # Script entry point: parse CLI options, pin the process to one GPU, run main().
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser(
        description="Run the SDXL bias evaluation for one IQA setting and "
                    "write result/sdxl_<iqa>.csv."
    )
    # Restricting to the supported settings makes a typo fail immediately
    # with a usage error instead of producing an empty result file.
    parser.add_argument('--iqa', type=str, default='1', choices=['1', '2', '3'],
                        help="IQA pipeline to run (1, 2 or 3).")
    parser.add_argument('--root_dir', type=str, default='./',
                        help="Root directory; forwarded to BiasEvaluator via args.")
    parser.add_argument('--gpu_id', type=int, default=3,
                        help="Physical GPU index (PCI bus order) to expose to CUDA.")

    args = parser.parse_args()
    # Expose only the requested GPU; done before any CUDA query below so the
    # restriction takes effect. Inside this process it then appears as "cuda".
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    # Module-level global kept because main() may read it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    main(args)
